import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.tsa.filters.hp_filter import hpfilter
from linearmodels.panel import PanelOLS
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import patsy
from matplotlib.container import BarContainer

# Project root; every other location below is this root plus a fixed suffix.
path = '/home/jbs/Documents/Research/china60_akkrs/v3_jpe_final/'

# Unpacked in the same order as the suffix tuple.
dtapath, inpath, outpath, calpath, figpath, texpath = (
    path + subdir
    for subdir in (
        '../20 Intermediate Files/',  # dtapath
        'model/output/',              # inpath  (read by load_data)
        'scripts/output/',            # outpath
        'model/input/',               # calpath
        'scripts/figs/',              # figpath
        'scripts/tex/',               # texpath
    )
)

#############################################
# plotting stuff

# Saved figures: tight bounding box, PDF output by default.
mpl.rc('savefig', bbox='tight', format='pdf')
# Default marker size for line plots.
mpl.rc('lines', markersize=3)

# Shared styling constants (not referenced in this part of the file;
# presumably line width / text width / opacity used by plotting code elsewhere).
lw = 3
tw = 20
alpha = 0.8

# Plot palette, in the order series are expected to pick colours.
colors = [
    '#377eb8',
    '#e41a1c',
    '#4daf4a',
    '#984ea3',
    '#ff7f00',
    '#a65628',
]  #,'#ffff33']
#colors=['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd','#8c564b','#e377c2','#7f7f7f','#bcbd22','#17becf']

def slide_fig():
    """Create a 16x10-inch figure with 18-point tick and y-axis labels.

    ``sns.despine()`` is called without arguments right after the figure is
    created.

    Returns
    -------
    (fig, ax)
        The new matplotlib figure and its single axes.
    """
    label_size = 18

    fig, ax = plt.subplots(figsize=(16, 10))
    ax.yaxis.label.set_size(label_size)
    ax.tick_params(axis='both', labelsize=label_size)
    sns.despine()

    return fig,ax

def paper_fig(zero_line, ntr_line, pntr_line, vote_line):
    """Create a 4x4-inch figure spanning 1974-2008 with optional reference lines.

    Parameters
    ----------
    zero_line : bool
        Draw a solid horizontal line at y=0.
    ntr_line : bool
        Draw a dotted vertical line at x=1980.
    pntr_line : bool
        Draw a dotted vertical line at x=2001.
    vote_line : bool
        Draw a dotted vertical line at x=1990.

    Flags are evaluated by truthiness rather than compared with ``== True``.

    Returns
    -------
    (fig, ax)
        The new matplotlib figure and its single axes.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_xlim(1974, 2008)
    sns.despine()

    if zero_line:
        ax.axhline(y=0, color='black', lw=0.75)

    # Draw order is kept as 1980, 2001, 1990 so the artists are added to the
    # axes in the same sequence as before.
    for enabled, year in ((ntr_line, 1980), (pntr_line, 2001), (vote_line, 1990)):
        if enabled:
            ax.axvline(x=year, color='black', lw=0.75, linestyle=':')

    return fig,ax

def extract_fig_data(ax,fname,ax2=None):
    """Dump the data plotted on ``ax`` (and optionally ``ax2``) to a CSV file.

    Parameters
    ----------
    ax : axes-like
        Object exposing ``lines`` and ``containers``.
    fname : str
        Destination CSV path.
    ax2 : axes-like, optional
        Second axes whose lines and bars are appended after those of ``ax``.

    Output columns
    --------------
    xaxis
        x data of the first retained line.
    <line label>
        y data of every line with more than two points (lines with one or two
        points are skipped -- presumably reference lines).
    <bar label>
        For each ``BarContainer``, the height of its first child repeated
        ``len(xaxis)`` times.

    Columns of unequal length are padded with missing values. A later series
    with the same label overwrites an earlier one.
    """
    axes_list = [ax] if ax2 is None else [ax, ax2]

    # Lines from ax first, then ax2; the first one defines the x axis.
    curves = [ln for a in axes_list for ln in a.lines if len(ln.get_ydata())>2]
    xvals = curves[0].get_xdata()

    columns = {'xaxis': xvals}
    for ln in curves:
        columns[ln.get_label()] = ln.get_ydata()

    # Bars are written as a constant column at the first bar's height.
    for a in axes_list:
        for cont in a.containers:
            if isinstance(cont, BarContainer):
                height = cont.get_children()[0].get_height()
                columns[cont.get_label()] = np.ones(len(xvals)) * height

    # Wrapping each column in a Series lets columns of different length coexist.
    frame = pd.DataFrame({name: pd.Series(vals) for name, vals in columns.items()})
    frame.to_csv(fname, sep=',', index=False)

#############################################
# processing simulated data

# Key calendar years. N0 equals the left x-limit used by paper_fig; NR and NU
# equal the years of its 1980 and 2001 reference lines.
N0 = 1974
NR = 1980
NU = 2001

# The same dates (plus 1990) expressed as offsets from N0.
NR2, NU2, N1990 = (year - N0 for year in (NR, NU, 1990))

# Per-column aggregation spec: source column -> [(output name, reducer)].
# 'f' is reduced to its number of distinct values, 'v' to its sum.
# Not used in the visible part of the file; presumably passed to
# groupby(...).agg(...) elsewhere.
agg_fns = dict(
    f=[('nf', lambda col: col.nunique())],
    v=[('exports', lambda col: col.sum())],
)

def reset_multiindex(df,n,suff):
    """Flatten two-level column labels of ``df`` into plain strings, in place.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose columns are a two-level ``MultiIndex``.
    n : int
        Number of leading columns that keep their level-0 label.
    suff : str
        Suffix appended to the level-1 label of every column from position
        ``n`` onward.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, with its columns replaced.
    """
    cols = df.columns
    # get_level_values is used instead of levels[...][labels[...]] because
    # MultiIndex.labels was renamed to .codes and removed in pandas 1.0.
    head = cols.get_level_values(0)[:n].tolist()
    tail = [s+suff for s in cols.get_level_values(1)[n:].tolist()]
    df.columns = head + tail
    return df

def pct_chg(x):
    """Return ``x`` divided by its third element (position 2).

    Despite the name, the result is a ratio to that base observation, not a
    percentage.
    """
    base = x.iloc[2]
    return x / base
    
def growth(x):
    """Return the period-over-period growth rate of ``x`` in percent.

    The first element has no predecessor and is therefore missing.
    """
    previous = x.shift()
    return (x / previous - 1.0) * 100

def wavg(group, avg_name, weight_name):
    """Return the weighted mean of column ``avg_name`` using ``weight_name`` as weights.

    Parameters
    ----------
    group : pandas.DataFrame
        Frame (or groupby chunk) holding both columns.
    avg_name : str
        Column to average.
    weight_name : str
        Column of weights.
    """
    weights = group[weight_name]
    weighted_total = (group[avg_name] * weights).sum()
    return weighted_total / weights.sum()

def load_data(reform_flag=0, suff='', suff2='_baseline'):
    """Load one simulated panel from ``inpath`` and add derived columns.

    Parameters
    ----------
    reform_flag : int
        Selects the file stem: 0 -> ``simul_agg_det0``, 1 -> ``simul_agg_det1``,
        3 -> ``simul_agg_tpu``. Any other value matches no file.
    suff : str
        Unused; kept so existing calls keep working.
    suff2 : str
        Inserted between the stem and ``.csv``.

    Returns
    -------
    pandas.DataFrame or None
        The panel sorted by ``(i, y)``, or ``None`` (after printing a message)
        when the file is missing, unreadable or cannot be parsed.

    Added or rewritten columns
    --------------------------
    y                     input ``y`` plus 1971
    pre_2000              ``y < 2000``
    tau_nntr, tau_applied input values minus 1; before 1980 ``tau_applied``
                          is overwritten with ``tau_nntr``
    spread                log((1+tau_nntr)/(1+tau_applied)) measured in 2001,
                          constant within ``i``
    nf_lag, tau_lag, tau_lead, exports_lag
                          within-``i`` lags / lead
    delta_exports, delta_tau, delta_nf
                          log first differences
    extensive, intensive  ``num_exporters`` and ``exports/num_exporters``
    unit_value            ``exports/quantity`` (only if ``quantity`` exists)
    """
    stems = {0: 'simul_agg_det0', 1: 'simul_agg_det1', 3: 'simul_agg_tpu'}
    stem = stems.get(reform_flag)
    fname = '' if stem is None else stem + suff2 + '.csv'

    if stem is None:
        # Unknown flag: there is no file to look for. Same message and return
        # value as a failed read.
        print('\tSimulation data not found at path ' + fname)
        return None

    full_path = inpath + fname
    try:
        data = pd.read_csv(full_path)
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; pandas' EmptyDataError
        # and ParserError are ValueError subclasses. Anything else (e.g.
        # KeyboardInterrupt or a programming error) now propagates instead of
        # being reported as "not found".
        print('\tSimulation data not found at path ' + fname)
        return None

    data['y'] = data.y + 1971
    data['pre_2000'] = data.y<2000
    #data = data[(data.y>=1974) & (data.y<=2008)].reset_index(drop=True)

    # Subtract 1 from both tariff columns (presumably gross factor -> net rate).
    data['tau_nntr'] = data.tau_nntr - 1
    data['tau_applied'] = data.tau_applied-1
    early = data.y<1980
    data.loc[early,'tau_applied'] = data.loc[early,'tau_nntr']

    # Log gap between the two tariffs in 2001, merged back onto every year.
    d01 = data.loc[data.y==2001,:].reset_index(drop=True)
    d01['spread'] = np.log((1+d01.tau_nntr)/(1+d01.tau_applied))
    d01 = d01[['i','spread']].drop_duplicates()
    data = pd.merge(left=data,right=d01,how='left',on=['i'])

    # Lags and leads below rely on rows being ordered by year within each i.
    data.sort_values(by=['i','y'],ascending=[True,True],inplace=True)
    data.reset_index(drop=True,inplace=True)

    def shifted(col, periods=1):
        # Within-i shift; positive periods give a lag, negative a lead.
        return data.groupby(['i'])[col].shift(periods)

    data['nf_lag'] = shifted('num_exporters')
    data['tau_lag'] = shifted('tau_applied')
    data['tau_lead'] = shifted('tau_applied', -1)
    data['exports_lag'] = shifted('exports')

    data['delta_exports'] = np.log(data.exports) - np.log(data.exports_lag)
    data['delta_tau'] = np.log(1+data.tau_applied) - np.log(1+data.tau_lag)
    data['delta_nf'] = np.log(data.num_exporters) - np.log(data.nf_lag)

    data['extensive'] = data.num_exporters
    data['intensive'] = data.exports/data.num_exporters

    if 'quantity' in data.columns:
        data['unit_value'] = data.exports/data.quantity

    return data

#############################################
# regressions

def ecm_regression(df,suff2='_baseline'):
    """Fit an error-correction regression and compare it with reference values.

    Regresses ``delta_exports`` on ``log(1+tau_lag)``, ``log(exports_lag)`` and
    ``delta_tau`` with entity effects (no time effects), using only rows whose
    current and lagged exports both exceed 1e-8.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with columns ``i``, ``y``, ``exports``, ``exports_lag``,
        ``tau_lag``, ``delta_exports`` and ``delta_tau``.
    suff2 : str
        Unused; kept so existing calls keep working.

    Returns
    -------
    (float, float)
        ``SR - SR_true`` and ``LR - LR_true``, where SR is the coefficient on
        ``delta_tau`` and LR is minus the ratio of the ``tlag`` coefficient to
        the ``elag`` coefficient. Both estimates are also printed next to
        their reference values.
    """
    # Reference values the estimates are printed against and differenced from.
    SR_true = -2.23
    LR_true = -7.93

    usable = (df.exports>1e-8) & (df.exports_lag>1.0e-8)
    df2 = df.loc[usable,:].reset_index(drop=True)
    df2['tlag'] = np.log(1+df2.tau_lag)
    df2['elag'] = np.log(df2.exports_lag)
    df2.set_index(['i','y'],inplace=True)

    model = PanelOLS(dependent=df2['delta_exports'],
                     exog=df2[['tlag','elag','delta_tau']],
                     entity_effects=True,time_effects=False)
    eres3=model.fit()

    sr = eres3.params['delta_tau']
    lr = -eres3.params['tlag']/eres3.params['elag']

    print("\tSR: %0.4f (%0.4f)\n\tLR: %0.4f (%0.4f)" % (sr,SR_true,lr,LR_true))
    return (sr-SR_true), (lr-LR_true)

def gap_regression_PS(df,return_se=False):
    """Regress log exports on ``spread * pre_2000`` over 1992-2007.

    The model also includes a constant and ``log(1+tau_applied)``, with entity
    and time effects; standard errors are clustered by entity. Rows with
    exports at or below 1e-8 are dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with columns ``i``, ``y``, ``exports``, ``spread``, ``pre_2000``
        and ``tau_applied``. The caller's frame is not modified.
    return_se : bool
        When truthy, also return the standard error.

    Returns
    -------
    float or (float, float)
        Coefficient on ``spread * pre_2000``, optionally followed by its
        clustered standard error.
    """
    keep = (df.y>=1992) & (df.y<=2007) & (df.exports>1e-8)
    # Work on an explicit copy: adding columns to a boolean-filtered slice can
    # raise SettingWithCopyWarning.
    df = df.loc[keep].copy()

    df['X'] = df.spread * df.pre_2000
    df['Y'] = np.log(df.exports)
    df['Z'] = np.log(1+df.tau_applied)
    df['constant']=1
    df.set_index(['i','y'],inplace=True)

    model = PanelOLS(dependent=df['Y'],exog=df[['constant','X','Z']],
                     entity_effects=True,time_effects=True)
    res=model.fit(cov_type='clustered',cluster_entity=True)
    beta = res.params['X']
    se = res.std_errors['X']

    if return_se:
        return beta, se
    return beta

def gap_regression_AKKRS(df,col='exports',flag=0):
    """Estimate one coefficient on ``spread`` per year for ``log(col)``.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with columns ``i``, ``y``, ``spread`` and ``col``.
    col : str
        Outcome column. Rows where it is at or below 1e-8, or where ``y`` is
        outside 1974-2008, are dropped.
    flag : int
        0: the outcome is divided by its value in the last sample year (per
        ``i``) and the model includes time effects. Any other value: the
        outcome is used as is and no time effects are included.

    Returns
    -------
    numpy.ndarray
        Coefficients on ``C(y)[year]:spread``, ordered as the years first
        appear in the filtered sample.
    """
    in_sample = (df.y>=1974)&(df.y<=2008)&(df[col]>1e-8)
    df2 = df[in_sample].reset_index(drop=True)

    # Value of `col` in the final sample year, one row per i.
    dlast = df2.loc[df2.y==df2.y.max(),:].reset_index(drop=True)
    dlast.rename(columns={col:col+'_last'},inplace=True)
    dlast = dlast[['i',col+'_last']].drop_duplicates()

    df2 = pd.merge(left=df2,right=dlast,how='left',on=['i'])

    col2 = col+'2'
    if(flag==0):
        df2[col2] = df2[col]/df2[col+'_last']
    else:
        df2[col2] = df2[col]

    # The outcome, y and i are repeated on the right-hand side only so that
    # patsy returns them in the same frame as the C(y):spread interaction
    # columns; y and i then become the panel index and the outcome is split
    # off as the dependent variable.
    formula = 'np.log(%s) ~ np.log(%s) + y + i + C(y):spread' % ( (col2,col2) )
    depvar = 'np.log(%s)'%col2

    _,df3 = patsy.dmatrices(formula,df2,return_type="dataframe")
    df3.set_index(['i','y'],inplace=True)
    model = PanelOLS(dependent=df3[depvar],
                     exog=df3.loc[:,df3.columns != depvar],
                     entity_effects=False,time_effects=(flag==0) )

    res=model.fit(cov_type='clustered',cluster_entity=True)

    #print(res)

    # The extra smf.ols fit that used to follow was never read, so it is no
    # longer run on every call.
    years = df2.y.unique().tolist()
    effects = np.asarray([res.params['C(y)[%d]:spread'%y] for y in years])

    return effects

